
llama : support Jamba hybrid Transformer-Mamba models #7531


Merged (61 commits) on Jul 9, 2025

Conversation

compilade (Collaborator) commented May 25, 2024

This adds support for Jamba (paper: https://arxiv.org/abs/2403.19887) and fixes #6372.

(This PR has been open for a while, and its description originally had a much broader scope; feel free to look at the edit history.)

New features

  • Jamba support
    • The first hybrid Transformer+Mamba model in llama.cpp
  • State checkpoints for recurrent models
    • Works best when n_parallel is at least 3 or 4 times the number of actual users
    • Allows backtracking tokens from the end of the last generation without having to reprocess the whole context
      • Very useful with the server example when trimming the stop string
  • Variable GQA (see also OpenELM support #7359)
    • GGUF metadata {model}.attention.head_count_kv can now also be an array of int32_t, one value per layer
    • Layers with 0 KV heads are considered recurrent layers (Mamba, in the case of Jamba); see the sketch after this list.
    • This will make proper support of DeciLM possible
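
For illustration, here is a minimal Python sketch (not code from this PR; the helper name is made up) of how a per-layer head_count_kv array maps to layer types, using the values from the 12-layer jamba-900M log below:

def classify_layers(head_count_kv: list[int]) -> list[str]:
    # 0 KV heads marks a recurrent (Mamba) layer; any other value is an attention layer
    return ["attention" if n_kv > 0 else "recurrent" for n_kv in head_count_kv]

# jamba.attention.head_count_kv from the example output below
print(classify_layers([0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0]))
# -> ['recurrent', 'recurrent', 'attention', 'recurrent', ...]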

Internal changes

  • move the build_mamba_layer functions to a parent class shared by both llm_build_mamba and llm_build_jamba.
  • remove llm_graph_context::build_inp_mem_hybrid
    • Redundant, see next point.
  • remove llm_graph_input_mem_hybrid
    • It's redundant with llm_graph_input_rs and llm_graph_input_attn_kv_unified, and causes unnecessary duplication and overloads of build_rs and build_attn.

Future ideas

  • Recurrent state checkpoints, to allow for backtracking recurrent states
  • Fairly split the available KV cells among active sequences, similarly to RS cells.
  • Handle token shift (and Self-Extend?) when finding a slot.
    • This could help with the fair split of KV cells by freeing cells of sequences which use more than their fair share of cells.

Testing

Example output of jamba-900M-v0.13-KIx2:
$  ./bin/main -m /srv/LLMstash/tmp/jamba-900M.bf16.gguf --temp 0 -e -p "I believe the meaning of life is" --repeat-penalty 1.2 --repeat-last-n 256 -c 16384 -n 256
Log start
main: build = 3003 (0fd13e94)
main: built with gcc (GCC) 13.2.0 for x86_64-unknown-linux-gnu
main: seed  = 1716594011
llama_model_loader: loaded meta data with 26 key-value pairs and 189 tensors from /srv/LLMstash/tmp/jamba-900M.bf16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = jamba
llama_model_loader: - kv   1:                               general.name str              = jamba-900M-v0.13-KIx2
llama_model_loader: - kv   2:                          jamba.block_count u32              = 12
llama_model_loader: - kv   3:                       jamba.context_length u32              = 16384
llama_model_loader: - kv   4:                     jamba.embedding_length u32              = 1024
llama_model_loader: - kv   5:                  jamba.feed_forward_length u32              = 4096
llama_model_loader: - kv   6:                 jamba.attention.head_count u32              = 32
llama_model_loader: - kv   7:              jamba.attention.head_count_kv arr[i32,12]      = [0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0]
llama_model_loader: - kv   8:                      jamba.ssm.conv_kernel u32              = 4
llama_model_loader: - kv   9:                       jamba.ssm.inner_size u32              = 2048
llama_model_loader: - kv  10:                       jamba.ssm.state_size u32              = 16
llama_model_loader: - kv  11:                   jamba.ssm.time_step_rank u32              = 256
llama_model_loader: - kv  12:     jamba.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  13:                         jamba.expert_count u32              = 8
llama_model_loader: - kv  14:                    jamba.expert_used_count u32              = 2
llama_model_loader: - kv  15:                          general.file_type u32              = 32
llama_model_loader: - kv  16:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  17:                         tokenizer.ggml.pre str              = gpt-2
llama_model_loader: - kv  18:                      tokenizer.ggml.tokens arr[str,65024]   = ["<EOT>", "<META>", "<META_START>", "...
llama_model_loader: - kv  19:                  tokenizer.ggml.token_type arr[i32,65024]   = [3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  20:                      tokenizer.ggml.merges arr[str,64739]   = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "ĠĠ �...
llama_model_loader: - kv  21:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  22:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  23:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  24:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  25:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  121 tensors
llama_model_loader: - type bf16:   68 tensors
llm_load_vocab: special tokens definition check successful ( 29/65024 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = jamba
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 65024
llm_load_print_meta: n_merges         = 64739
llm_load_print_meta: n_ctx_train      = 16384
llm_load_print_meta: n_embd           = 1024
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 12
llm_load_print_meta: n_rot            = 32
llm_load_print_meta: n_embd_head_k    = 32
llm_load_print_meta: n_embd_head_v    = 32
llm_load_print_meta: n_gqa            = 0
llm_load_print_meta: n_embd_k_gqa     = 0
llm_load_print_meta: n_embd_v_gqa     = 0
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 4096
llm_load_print_meta: n_expert         = 8
llm_load_print_meta: n_expert_used    = 2
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = -1
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 16384
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 4
llm_load_print_meta: ssm_d_inner      = 2048
llm_load_print_meta: ssm_d_state      = 16
llm_load_print_meta: ssm_dt_rank      = 256
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = BF16
llm_load_print_meta: model params     = 887.66 M
llm_load_print_meta: model size       = 1.67 GiB (16.19 BPW) 
llm_load_print_meta: general.name     = jamba-900M-v0.13-KIx2
llm_load_print_meta: BOS token        = 0 '<EOT>'
llm_load_print_meta: EOS token        = 0 '<EOT>'
llm_load_print_meta: UNK token        = 0 '<EOT>'
llm_load_print_meta: PAD token        = 0 '<EOT>'
llm_load_print_meta: LF token         = 133 'Ä'
llm_load_tensors: ggml ctx size =    0.09 MiB
llm_load_tensors:        CPU buffer size =  1713.16 MiB
......................................
llama_new_context_with_model: n_ctx      = 16384
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_cache_init:        CPU cache buf size =    49.34 MiB
llama_new_context_with_model: SSM state size =     1.34 MiB, R (f32):    0.21 MiB, S (f32):    1.12 MiB
llama_new_context_with_model: KV cache size  =    48.00 MiB, K (f16):   24.00 MiB, V (f16):   24.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.25 MiB
llama_new_context_with_model:        CPU compute buffer size =  1062.03 MiB
llama_new_context_with_model: graph nodes  = 621
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 2 / 4 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
        repeat_last_n = 256, repeat_penalty = 1.200, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 16384, n_batch = 2048, n_predict = 256, n_keep = 0


<EOT>I believe the meaning of life is not to be found in a single word, but rather as an expression of one's own feelings and thoughts.

The idea that we are all born with our bodies, whether they are human or animal, has been around for centuries. It was believed by some that it was something like a body made up of bones, which were attached to each other at birth. The most common form of this type of bone is called a "bone." This is what makes it so hard to tell if you're alive or dead. In fact, there are many different types of bones, including those that have been used for various purposes such as healing wounds, wounding wounds, etc.

In ancient times, people had a lot of teeth, and these were often very small. They could also be placed on top of their heads, where they would sit down and look at them. These were usually large, round stones, which were sometimes covered with hair. When the skin was removed from the head, the bones became more prominent, and the muscles began to grow larger.

This kind of bone was known as a "bone" because it was made out of two parts: the outermost part (the innermost portion) and the innermost part (the outermost
llama_print_timings:        load time =     252.28 ms
llama_print_timings:      sample time =     303.07 ms /   256 runs   (    1.18 ms per token,   844.68 tokens per second)
llama_print_timings: prompt eval time =     200.72 ms /     8 tokens (   25.09 ms per token,    39.86 tokens per second)
llama_print_timings:        eval time =   12516.79 ms /   255 runs   (   49.09 ms per token,    20.37 tokens per second)
llama_print_timings:       total time =   13213.95 ms /   263 tokens
Log end

@compilade added the enhancement, model, refactoring, need feedback, embeddings, python, and Review Complexity : High labels on May 25, 2024
@compilade marked this pull request as draft May 25, 2024 03:38
llama.cpp Outdated
Comment on lines 5244 to 5248
switch (hparams.n_layer) {
// TODO: Jamba layers are a bit heterogeneous, so naming this is hard.
case 12: // 900M 8x???M
case 32: // 51B 16x?B
default: model.type = e_model::MODEL_UNKNOWN;
compilade (Collaborator Author):

I'm not sure what model size type(s) I should give to Jamba models.

@github-actions bot added the ggml (changes relating to the ggml tensor library for machine learning) label May 25, 2024
github-actions bot commented May 25, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 557 iterations 🚀

  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8384.34ms p(95)=20451.68ms fails=, finish reason: stop=510 truncated=47
  • Prompt processing (pp): avg=102.96tk/s p(95)=478.95tk/s
  • Token generation (tg): avg=36.48tk/s p(95)=48.13tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=compilade/refactor-kv-cache commit=fee3c1d740c0e027c81e2f2f3fb48d619857175f

Benchmark charts omitted (llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 557 iterations):
  • llamacpp:prompt_tokens_seconds
  • llamacpp:predicted_tokens_seconds
  • llamacpp:kv_cache_usage_ratio
  • llamacpp:requests_processing

@arch-btw (Contributor) commented:

Great job! Works for me too, it's very fast. There were some warnings during compilation, but nothing major.

<EOT>Hello!

I'll get a new one for you and I think this is going to be really cool, so good. And I'm sure there's lots of ways in which [...]

llama_print_timings:        load time =     286.42 ms
llama_print_timings:      sample time =     155.94 ms /   256 runs   (    0.61 ms per token,  1641.63 tokens per second)
llama_print_timings: prompt eval time =      70.77 ms /     3 tokens (   23.59 ms per token,    42.39 tokens per second)
llama_print_timings:        eval time =    9368.54 ms /   255 runs   (   36.74 ms per token,    27.22 tokens per second)
llama_print_timings:       total time =    9686.16 ms /   258 tokens

@TechxGenus commented:

Amazing work!
I initially tested Jamba-v0.1 on a machine with 500 GB of RAM and it worked great!

./main -m ./Jamba-v0.1-hf-00001-of-00024.gguf -n 120 --prompt "def max(arr):" --temp 0
Log start
main: build = 3006 (fc59407e)
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: seed  = 1716710334
llama_model_loader: additional 23 GGUFs metadata loaded.
llama_model_loader: loaded meta data with 31 key-value pairs and 531 tensors from ./Jamba-v0.1-hf-00001-of-00024.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = jamba
llama_model_loader: - kv   1:                               general.name str              = Jamba-v0.1-hf
llama_model_loader: - kv   2:                          jamba.block_count u32              = 32
llama_model_loader: - kv   3:                       jamba.context_length u32              = 262144
llama_model_loader: - kv   4:                     jamba.embedding_length u32              = 4096
llama_model_loader: - kv   5:                  jamba.feed_forward_length u32              = 14336
llama_model_loader: - kv   6:                 jamba.attention.head_count u32              = 32
llama_model_loader: - kv   7:              jamba.attention.head_count_kv arr[i32,32]      = [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, ...
llama_model_loader: - kv   8:                      jamba.ssm.conv_kernel u32              = 4
llama_model_loader: - kv   9:                       jamba.ssm.inner_size u32              = 8192
llama_model_loader: - kv  10:                       jamba.ssm.state_size u32              = 16
llama_model_loader: - kv  11:                   jamba.ssm.time_step_rank u32              = 256
llama_model_loader: - kv  12:     jamba.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  13:                         jamba.expert_count u32              = 16
llama_model_loader: - kv  14:                    jamba.expert_used_count u32              = 2
llama_model_loader: - kv  15:                          general.file_type u32              = 32
llama_model_loader: - kv  16:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  17:                         tokenizer.ggml.pre str              = default
llama_model_loader: - kv  18:                      tokenizer.ggml.tokens arr[str,65536]   = ["<|pad|>", "<|startoftext|>", "<|end...
llama_model_loader: - kv  19:                      tokenizer.ggml.scores arr[f32,65536]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  20:                  tokenizer.ggml.token_type arr[i32,65536]   = [3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, ...
llama_model_loader: - kv  21:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  22:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  23:            tokenizer.ggml.unknown_token_id u32              = 3
llama_model_loader: - kv  24:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  25:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  26:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  27:               general.quantization_version u32              = 2
llama_model_loader: - kv  28:                                   split.no u16              = 0
llama_model_loader: - kv  29:                                split.count u16              = 24
llama_model_loader: - kv  30:                        split.tensors.count i32              = 531
llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type bf16:  170 tensors
llm_load_vocab: special tokens definition check successful ( 1799/65536 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = jamba
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 65536
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 262144
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 0
llm_load_print_meta: n_embd_k_gqa     = 0
llm_load_print_meta: n_embd_v_gqa     = 0
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 16
llm_load_print_meta: n_expert_used    = 2
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = -1
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 262144
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 4
llm_load_print_meta: ssm_d_inner      = 8192
llm_load_print_meta: ssm_d_state      = 16
llm_load_print_meta: ssm_dt_rank      = 256
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = BF16
llm_load_print_meta: model params     = 51.57 B
llm_load_print_meta: model size       = 96.30 GiB (16.04 BPW) 
llm_load_print_meta: general.name     = Jamba-v0.1-hf
llm_load_print_meta: BOS token        = 1 '<|startoftext|>'
llm_load_print_meta: EOS token        = 2 '<|endoftext|>'
llm_load_print_meta: UNK token        = 3 '<|unk|>'
llm_load_print_meta: PAD token        = 0 '<|pad|>'
llm_load_print_meta: LF token         = 1554 '<0x0A>'
llm_load_print_meta: EOT token        = 2 '<|endoftext|>'
llm_load_tensors: ggml ctx size =    0.24 MiB
llm_load_tensors:        CPU buffer size =  4851.72 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4210.03 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  5095.47 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  3584.03 MiB
llm_load_tensors:        CPU buffer size =  4210.03 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4210.03 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4210.03 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4851.77 MiB
llm_load_tensors:        CPU buffer size =  3584.03 MiB
..............................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: n_batch    = 512
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_cache_init:        CPU cache buf size =    24.63 MiB
llama_new_context_with_model: SSM state size =    16.62 MiB, R (f32):    2.62 MiB, S (f32):   14.00 MiB
llama_new_context_with_model: KV cache size  =     8.00 MiB, K (f16):    4.00 MiB, V (f16):    4.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.25 MiB
llama_new_context_with_model:        CPU compute buffer size =   145.10 MiB
llama_new_context_with_model: graph nodes  = 1730
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 32 / 64 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 512, n_batch = 2048, n_predict = 120, n_keep = 1


<|startoftext|> def max(arr):
    return max(arr)


def min(arr):
    return min(arr)


def mean(arr):
    return sum(arr) / len(arr)


def median(arr):
    arr.sort()
    if len(arr) % 2 == 0:
        return (arr[len(arr) // 2] + arr[len(arr) // 2 - 1]) / 2
    else:
        return arr[len(arr) // 2]


llama_print_timings:        load time =   82494.54 ms
llama_print_timings:      sample time =       9.61 ms /   120 runs   (    0.08 ms per token, 12490.89 tokens per second)
llama_print_timings: prompt eval time =     666.33 ms /     6 tokens (  111.06 ms per token,     9.00 tokens per second)
llama_print_timings:        eval time =   27656.31 ms /   119 runs   (  232.41 ms per token,     4.30 tokens per second)
llama_print_timings:       total time =   28862.18 ms /   125 tokens
Log end

ggml.c Outdated
Comment on lines 16264 to 16267
    if (n_rs > 1) {
        // multiple sequences means it's hard to know when it's the first time a state is read,
        // so copy them all over to the destination, just to be sure.
-       for (int i3 = 0; i3 < n_kv; ++i3) {
+       for (int i3 = 0; i3 < n_rs; ++i3) {
ggerganov (Member):

I'm looking at adding the missing Metal kernels for SSM_CONV and SSM_SCAN. I'm wondering if this part of the kernels where we copy src0 -> dst could be extracted outside of the operation via ggml_cpy + ggml_view or ggml_acc? Would simplify the implementation

Also, I still haven't understood the details of the computation, but if we find a way to express these ops via existing ops altogether (e.g. using ggml_conv, ggml_mul_mat, ...), it would be preferred to do so, in order to reduce the number of kernels that we have to write.

compilade (Collaborator Author):

> I'm wondering if this part of the kernels where we copy src0 -> dst could be extracted outside of the operation via ggml_cpy + ggml_view or ggml_acc? Would simplify the implementation

Yes, this is definitely possible. I'll find a way to extract the copies outside.

> if we find a way to express these ops via existing ops altogether (e.g. using ggml_conv, ggml_mul_mat, ...), it would be preferred to do so, in order to reduce the number of kernels that we have to write.

For SSM_SCAN, I think there's a way to fully express it in terms of other ops, though it will use much more memory because of the big intermediate tensors, and new operators like SOFT_PLUS and EXP would be needed instead. But different lengths of simultaneous sequences might make a custom operator still necessary. I'll think about ways to make it simpler, especially since other recurrent architectures (like RWKV) will also need to work on multiple sequences per batch.
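
For reference, the per-token selective-scan recurrence from the Mamba paper can be written in a few lines of NumPy. This is a single-sequence sketch for illustration only (it is not the ggml SSM_SCAN kernel, and the function name is made up), but it shows where the softplus and exp operations mentioned above come in:

import numpy as np

def selective_scan(x, dt, A, B, C, D):
    # x, dt: (T, d_inner); A: (d_inner, d_state); B, C: (T, d_state); D: (d_inner,)
    h = np.zeros((x.shape[1], A.shape[1]), dtype=x.dtype)   # recurrent state
    y = np.empty_like(x)
    dt = np.log1p(np.exp(dt))                                # softplus (hence SOFT_PLUS)
    for t in range(x.shape[0]):
        dA  = np.exp(dt[t][:, None] * A)                     # discretized A (hence EXP)
        dBx = dt[t][:, None] * B[t][None, :] * x[t][:, None]
        h = h * dA + dBx                                     # state update
        y[t] = (h * C[t][None, :]).sum(axis=-1) + D * x[t]
    return y, h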

For simplifying SSM_CONV, I don't think ggml_conv supports working on independent 1D rolling windows with varying sequence lengths.

When working on a single sequence, though, it's quite simple to do the equivalent of ggml_ssm_conv with a self-overlapping view, as I did in my original implementation which I described in more detail in #5328 (comment):

https://github.com/ggerganov/llama.cpp/blob/64fbce052373faf07a36b599528f8fe1cb1d62fb/llama.cpp#L6973-L6982

Setting nb[2] to the element size makes the view self-overlapping.

But this would create too many nodes in the compute graph when done with multiple sequences (unless they're always all the same length in which case the 4th dimension could be used), so a custom operator is necessary.
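
A NumPy analogue of that single-sequence trick (a sketch for illustration, not the ggml implementation) is a strided sliding-window view, where the windows share memory instead of being copied:

import numpy as np

d_conv, d_inner, n_tokens = 4, 8, 5
# conv state (the d_conv - 1 previous columns) concatenated with the new tokens
x = np.random.rand(d_inner, d_conv - 1 + n_tokens).astype(np.float32)
conv_w = np.random.rand(d_inner, d_conv).astype(np.float32)

# windows[:, t, :] is x[:, t : t + d_conv]; no data is copied
windows = np.lib.stride_tricks.sliding_window_view(x, d_conv, axis=1)
y = (windows * conv_w[:, None, :]).sum(axis=-1)  # depthwise 1D convolution
assert y.shape == (d_inner, n_tokens)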

@ggerganov (Member) commented May 26, 2024:

One idea that we might consider is to unfuse the n_rs dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch

The main goal would be to simplify the SSM operators, and potentially express them as other existing ops if possible. But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has its own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention. The main purpose of supporting this mode would be to achieve reproducible results during parallel decoding (currently, decoding the same sequence in parallel can yield slightly different results due to the unified KV cache).

Just throwing some thoughts that I have so far - will continue looking at the PR in the next days

Edit: I was writing this comment before I saw you posted - will take a look tomorrow

@compilade (Collaborator Author) commented May 26, 2024:

> One idea that we might consider is to unfuse the n_rs dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch

Yes, this would be doable, but would make the number of compute graph nodes scale with the number of sequences. (EDIT: if it's split when making ubatches, then the number of compute graph nodes can stay constant)

Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.

The recurrent steps are simpler for ubatches with sequence lengths of 1, but prompt processing performance would be much slower than with a per-recurrent-architecture operator for longer sequences. Still thinking about ways to generalize this while keeping good performance.

> But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has its own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention.

For the transformer KV cache, if there's logic to make all sequences within a ubatch have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.

I also think there's a way to keep the unified KV cache (one buffer) and chunk it to make each sequence have their own independent contiguous reserved cells. Batching sequences together might still be possible though, if the KQ mask gets another dimension (the number of sequences in the ubatch, and the number of new tokens per sequence instead of the batch size) so that these equal-sized "chunks" get processed independently in parallel. But this might not work (because the newly-calculated KV cells have to be copied in a bunch of not-regularly-spaced places), unless... unless maybe with some kind of ggml_set_rows? Not sure about the transposed V cache, though.

A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).

ggerganov (Member):

> if it's split when making ubatches, then the number of compute graph nodes can stay constant

No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance

> Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.

> For the transformer KV cache, if there's logic to make all sequences within a ubatch have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.

Not sure how that would work. Adding dummy tokens sounds like too much overhead (at least in the case of the regular transformer). Any other ideas?

> A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).

From a broad PoV, if we have an implementation that works with a single sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using a separate cache for each sequence. I didn't consider the number of nodes until you noted it - so that might be a problem indeed.

I'm currently working on a big refactor of how Mamba (and Jamba) works to make all sequences of a sub-batch be of the same length (initially only for models with recurrent states), and to make recurrent state slots contiguous, with the goal of simplifying the SSM operations (and removing GGML_OP_SSM_CONV), so that GPU support will be much easier to implement after that.

Looking forward to this!

@compilade (Collaborator Author) commented May 30, 2024:

> No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance

It will sacrifice some performance, but only in the cases where a batch contains an unequal number of tokens for each affected sequence. So this should not affect large prompt processing or parallel text generation, if both are not done in the same batch.

> Not sure how that would work. Adding dummy tokens sounds like too much overhead (at least in the case of the regular transformer). Any other ideas?

This is not about adding dummy tokens, but about making the number of new tokens in each ubatch the same per sequence. I think the overhead will be minimal, though there is still some.

Let me illustrate.

Let's say there's a batch with new tokens for 4 sequences of lengths 16, 7, 1, and 1, respectively.

0: ################
1: #######
2: #
3: #

Splitting that into equal-length sequences would make 3 ubatches, like so:

0: #
1: #
2: #
3: #
0: ######
1: ######
0: #########

Each of these shapes is nice and rectangular, which is good for recurrent architectures because their operations can be more easily batched across sequences this way.

But I'm not yet sure if it would also benefit Transformers, which is why I'm thinking of initially only enabling the equal-length splitting for recurrent (or hybrid) model architectures.

> From a broad PoV, if we have an implementation that works with a single sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using a separate cache for each sequence. I didn't consider the number of nodes until you noted it - so that might be a problem indeed.

Doing this with a constant number of graph nodes is pretty much what using same-length sequences (as illustrated above) allows, because the split into same-sequence tokens can then simply become another tensor dimension.
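
As a small illustration of the equal-length split described above (a sketch only, not llama.cpp's ubatch-splitting code; the helper name is made up), the following reproduces the 16/7/1/1 example:

def split_equal(seq_lens: list[int]) -> list[list[tuple[int, int]]]:
    # Each ubatch takes the same number of new tokens from every sequence
    # that still has tokens left, so every ubatch is "rectangular".
    remaining = list(seq_lens)
    ubatches = []
    while any(remaining):
        active = [i for i, n in enumerate(remaining) if n > 0]
        take = min(remaining[i] for i in active)
        ubatches.append([(seq, take) for seq in active])
        for i in active:
            remaining[i] -= take
    return ubatches

# [16, 7, 1, 1] -> 3 ubatches: 4 seqs x 1 token, 2 seqs x 6 tokens, 1 seq x 9 tokens
print(split_equal([16, 7, 1, 1]))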

ggerganov (Member):

Aha, got it. Good idea. I'm also not sure if this can help Transformers, but it's something to think about 👍

@gabe-l-hart (Contributor) commented:

@compilade Is there any outstanding testing before merging this that I can help with?

@compilade (Collaborator Author) commented:

@gabe-l-hart There's Jamba 1.7, which was released recently, and I was meaning to test at least Jamba-Mini-1.7 to see if it works (including with --jinja for the chat template, e.g. with -cnv in llama-cli).

Since my network is quite slow, the only way I can test it in a reasonable amount of time would be on a remote instance, but I haven't gotten around to doing that yet (I might today).

@gabe-l-hart (Contributor) commented:

Alright, let's see how fast the downloads are on my CUDA box ⌛

@gabe-l-hart (Contributor) commented Jul 8, 2025

Things are looking good!

(NOTE: Built off of GraniteFour which includes this branch as of 07c252f)

Setup

python convert_hf_to_gguf.py ~/models/ai21labs/AI21-Jamba-Mini-1.7/
./build/bin/llama-quantize /home/ghart/models/ai21labs/AI21-Jamba-Mini-1.7/AI21-Jamba-Mini-1.7-F16.gguf Q4_K_M
./build/bin/llama-cli -m ~/models/ai21labs/AI21-Jamba-Mini-1.7/ggml-model-Q4_K_M.gguf -p "You are a helpful AI assistant" --jinja

Results

chat.log
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 2 CUDA devices:
  Device 0: NVIDIA L40S, compute capability 8.9, VMM: yes
  Device 1: NVIDIA L40S, compute capability 8.9, VMM: yes
build: 5808 (a9dcc845) with cc (GCC) 14.2.1 20250110 (Red Hat 14.2.1-7) for x86_64-redhat-linux
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_load_from_file_impl: using device CUDA0 (NVIDIA L40S) - 45025 MiB free
llama_model_load_from_file_impl: using device CUDA1 (NVIDIA L40S) - 45025 MiB free
llama_model_loader: loaded meta data with 37 key-value pairs and 531 tensors from /home/ghart/models/ai21labs/AI21-Jamba-Mini-1.7/ggml-model-Q4_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = jamba
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = AI21 Jamba Mini 1.7
llama_model_loader: - kv   3:                            general.version str              = 1.7
llama_model_loader: - kv   4:                           general.basename str              = AI21-Jamba
llama_model_loader: - kv   5:                         general.size_label str              = Mini
llama_model_loader: - kv   6:                            general.license str              = other
llama_model_loader: - kv   7:                       general.license.name str              = jamba-open-model-license
llama_model_loader: - kv   8:                       general.license.link str              = https://www.ai21.com/jamba-open-model...
llama_model_loader: - kv   9:                          jamba.block_count u32              = 32
llama_model_loader: - kv  10:                       jamba.context_length u32              = 262144
llama_model_loader: - kv  11:                     jamba.embedding_length u32              = 4096
llama_model_loader: - kv  12:                  jamba.feed_forward_length u32              = 14336
llama_model_loader: - kv  13:                 jamba.attention.head_count u32              = 32
llama_model_loader: - kv  14:              jamba.attention.head_count_kv arr[i32,32]      = [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, ...
llama_model_loader: - kv  15:                      jamba.ssm.conv_kernel u32              = 4
llama_model_loader: - kv  16:                       jamba.ssm.inner_size u32              = 8192
llama_model_loader: - kv  17:                       jamba.ssm.state_size u32              = 16
llama_model_loader: - kv  18:                   jamba.ssm.time_step_rank u32              = 256
llama_model_loader: - kv  19:     jamba.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  20:                         jamba.expert_count u32              = 16
llama_model_loader: - kv  21:                    jamba.expert_used_count u32              = 2
llama_model_loader: - kv  22:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  23:                         tokenizer.ggml.pre str              = default
llama_model_loader: - kv  24:                      tokenizer.ggml.tokens arr[str,65536]   = ["<|pad|>", "<|startoftext|>", "<|end...
llama_model_loader: - kv  25:                      tokenizer.ggml.scores arr[f32,65536]   = [-1000.000000, -1000.000000, -1000.00...
llama_model_loader: - kv  26:                  tokenizer.ggml.token_type arr[i32,65536]   = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...
llama_model_loader: - kv  27:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  28:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  29:            tokenizer.ggml.unknown_token_id u32              = 3
llama_model_loader: - kv  30:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  31:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  32:               tokenizer.ggml.add_sep_token bool             = false
llama_model_loader: - kv  33:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  34:                    tokenizer.chat_template str              = {# Variables #}\n{% set ns = namespace...
llama_model_loader: - kv  35:               general.quantization_version u32              = 2
llama_model_loader: - kv  36:                          general.file_type u32              = 15
llama_model_loader: - type  f32:  305 tensors
llama_model_loader: - type q4_K:  207 tensors
llama_model_loader: - type q6_K:   19 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = Q4_K - Medium
print_info: file size   = 29.02 GiB (4.83 BPW) 
load: special tokens cache size = 1543
load: token to piece cache size = 0.4034 MB
print_info: arch             = jamba
print_info: vocab_only       = 0
print_info: n_ctx_train      = 262144
print_info: n_embd           = 4096
print_info: n_layer          = 32
print_info: n_head           = 32
print_info: n_head_kv        = [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0]
print_info: n_rot            = 128
print_info: n_swa            = 0
print_info: is_swa_any       = 0
print_info: n_embd_head_k    = 128
print_info: n_embd_head_v    = 128
print_info: n_gqa            = [0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]
print_info: n_embd_k_gqa     = [0, 0, 0, 0, 1024, 0, 0, 0, 0, 0, 0, 0, 1024, 0, 0, 0, 0, 0, 0, 0, 1024, 0, 0, 0, 0, 0, 0, 0, 1024, 0, 0, 0]
print_info: n_embd_v_gqa     = [0, 0, 0, 0, 1024, 0, 0, 0, 0, 0, 0, 0, 1024, 0, 0, 0, 0, 0, 0, 0, 1024, 0, 0, 0, 0, 0, 0, 0, 1024, 0, 0, 0]
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-06
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: f_attn_scale     = 0.0e+00
print_info: n_ff             = 14336
print_info: n_expert         = 16
print_info: n_expert_used    = 2
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = -1
print_info: rope scaling     = linear
print_info: freq_base_train  = 10000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 262144
print_info: rope_finetuned   = unknown
print_info: ssm_d_conv       = 4
print_info: ssm_d_inner      = 8192
print_info: ssm_d_state      = 16
print_info: ssm_dt_rank      = 256
print_info: ssm_n_group      = 0
print_info: ssm_dt_b_c_rms   = 0
print_info: model type       = ?B
print_info: model params     = 51.57 B
print_info: general.name     = AI21 Jamba Mini 1.7
print_info: vocab type       = SPM
print_info: n_vocab          = 65536
print_info: n_merges         = 0
print_info: BOS token        = 1 '<|startoftext|>'
print_info: EOS token        = 2 '<|endoftext|>'
print_info: EOT token        = 2 '<|endoftext|>'
print_info: UNK token        = 3 '<|unk|>'
print_info: PAD token        = 0 '<|pad|>'
print_info: LF token         = 1554 '<0x0A>'
print_info: EOG token        = 2 '<|endoftext|>'
print_info: max token length = 96
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors: offloading 0 repeating layers to GPU
load_tensors: offloaded 0/33 layers to GPU
load_tensors:   CPU_Mapped model buffer size = 29717.67 MiB
............................................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 4096
llama_context: n_ctx_per_seq = 4096
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = 0
llama_context: freq_base     = 10000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_per_seq (4096) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
llama_context:        CPU  output buffer size =     0.25 MiB
llama_kv_cache_unified:        CPU KV buffer size =    64.00 MiB
llama_kv_cache_unified: size =   64.00 MiB (  4096 cells,   4 layers,  1 seqs), K (f16):   32.00 MiB, V (f16):   32.00 MiB
llama_kv_cache_unified: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility
llama_memory_recurrent: mem_size = 1, n_seq_max = 1, type_r = 'f32', type_s = 'f32', n_layer = 32
llama_memory_recurrent:        CPU KV buffer size =    16.62 MiB
llama_memory_recurrent: KV self size  =   16.62 MiB, R (f32):    2.62 MiB, S (f32):   14.00 MiB
llama_context:      CUDA0 compute buffer size =   839.22 MiB
llama_context:  CUDA_Host compute buffer size =   276.00 MiB
llama_context: graph nodes  = 2166
llama_context: graph splits = 565 (with bs=512), 1 (with bs=1)
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 24
main: chat template is available, enabling conversation mode (disable it with -no-cnv)
*** User-specified prompt will pre-start conversation, did you mean to set --system-prompt (-sys) instead?
main: chat template example:
<|bom|><|system|> You are a helpful assistant<|eom|><|bom|><|user|> Hello<|eom|><|bom|><|assistant|> Hi there<|eom|><|bom|><|user|> How are you?<|eom|><|bom|><|assistant|>

system_info: n_threads = 24 (n_threads_batch = 24) / 48 | CUDA : ARCHS = 890 | USE_GRAPHS = 1 | PEER_MAX_BATCH_SIZE = 128 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | AMX_INT8 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 

main: interactive mode on.
sampler seed: 3821138688
sampler params: 
	repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
	dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
	top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
	mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-n-sigma -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist 
generate: n_ctx = 4096, n_batch = 2048, n_predict = -1, n_keep = 0

== Running in interactive mode. ==
 - Press Ctrl+C to interject at any time.
 - Press Return to return control to the AI.
 - To return control without starting a new line, end your input with '/'.
 - If you want to submit another line, end your input with '\'.
 - Not using system message. To change it, set a different value via -sys PROMPT

    You are a helpful AI assistant Absolutely! How can I assist you today?

> Tell me a story about a developer and their dog
 Sure! Here's a heartwarming story about a developer and their dog:


---

**The Developer and Their Dog**

In a cozy little apartment in the bustling city of San Francisco, lived a young developer named Alex. Alex was passionate about coding and spent most of their days and nights working on various projects, always eager to learn new technologies and improve their skills. Their apartment was a haven of technology, filled with the latest gadgets and a comfortable desk where they spent countless hours in front of a glowing screen.

One chilly evening, after a particularly long day of debugging and refactoring, Alex decided to take a break. They stood up from their desk, stretched their arms, and looked out the window. The city lights twinkled like stars in the night sky, and a soft, cool breeze wafted in through the open window.

"Maybe it's time for a walk," Alex thought to themselves, and with a smile, they called out, "Buddy! It's time for a walk!"

A moment later, a wagging tail and a pair of eager eyes appeared from the bedroom. Buddy, Alex's loyal dog, bounded into the living room, his furry coat shining in the dim light. Buddy was a medium-sized mixed-breed dog, with a patch of black fur on his back and a mischievous sparkle in his eyes.

"Let's go, Buddy!" Alex exclaimed, grabbing the dog's leash from the hook by the door.

Together, they stepped out into the crisp night air. The city was alive with the sounds of traffic and the chatter of people, but Alex and Buddy walked in silence, enjoying the peacefulness of the evening. They strolled past the tall buildings, their reflections shimmering in the glass windows, and down a quiet street lined with trees.

As they walked, Alex thought about the challenges they faced at work that day. Debugging a particularly tricky bug had been frustrating, but they had finally managed to solve it. The sense of accomplishment was immense, and Alex felt a wave of gratitude for the challenges that pushed them to grow.

Buddy, sensing Alex's calm and content mood, wagged his tail and looked up at them with a smile. Alex reached down and gave him a gentle pat on the head.

"You're the best, Buddy," Alex said softly.

Buddy barked softly in response, as if agreeing.

They continued their walk, eventually making their way to a small park nearby. Alex unclipped Buddy's leash, and the dog ran off to explore, sniffing the grass and chasing after squirrels. Alex sat down on a bench, watching Buddy with a smile on their face.

"Life is good, isn't it, Buddy?" Alex mused.

Buddy looked back at them, his eyes sparkling with the same contentment Alex felt.

After a while, Alex stood up and called out, "Time to go home, Buddy!"

Buddy bounded back, his tail wagging happily, and they walked back to the apartment. Once inside, Alex prepared a cozy spot for Buddy with his favorite blanket, and they settled down for the night.

"Goodnight, Buddy," Alex said, giving Buddy a final pat.

Buddy let out a soft sigh, his body relaxed and content, and soon they were both fast asleep, dreaming of coding adventures and endless possibilities.


---

The next morning, as Alex sipped their coffee and prepared to dive into another day of coding, they couldn't help but smile at the thought of their loyal companion, Buddy. Life was indeed good, and with Buddy by their side, Alex felt ready to conquer any challenge that came their way.

> EOF by user


llama_perf_sampler_print:    sampling time =      58.34 ms /   816 runs   (    0.07 ms per token, 13986.49 tokens per second)
llama_perf_context_print:        load time =    1529.39 ms
llama_perf_context_print: prompt eval time =     988.11 ms /    34 tokens (   29.06 ms per token,    34.41 tokens per second)
llama_perf_context_print:        eval time =   57717.84 ms /   809 runs   (   71.34 ms per token,    14.02 tokens per second)
llama_perf_context_print:       total time =   86718.96 ms /   843 tokens
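
As a side note on the memory numbers in the log above (a reading of the output, not something printed beyond what is shown): the hybrid split is directly visible. Only the attention layers get a unified KV cache, and its 64 MiB total works out to 4096 cells × 4 layers × 2048 bytes for K plus the same for V at f16 (32 MiB + 32 MiB), i.e. roughly 1024 f16 values per K and per V per layer per cell. The recurrent layers instead keep a fixed-size state of 16.62 MiB per sequence (R: 2.62 MiB, S: 14.00 MiB) that does not grow with context length.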

@gabe-l-hart
Copy link
Contributor

I updated to the latest tip (f716358) and continue to see good results (yet another story about Alex the developer, this time with Luna as the trusty canine sidekick)

@compilade compilade added the "merge ready" label (indicates that this may be ready to merge soon and is just holding out in case of objections) Jul 8, 2025
@@ -10009,16 +10056,15 @@ struct llm_build_mamba : public llm_graph_context {

         // TODO: skip computing output earlier for unused tokens

-        y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d));
+        y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
         y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
Copy link
Collaborator

Suggested change
-        y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
+        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);

Copy link
Collaborator Author

Thanks, I didn't know about this operator. I see it was added recently (#14158). Seems useful.

Copy link
Collaborator

Yep, and whenever Vulkan supports non-contiguous input we can remove the ggml_cont.

Copy link
Contributor

I remember adding this ggml_cont to avoid an assertion error in the Metal backend.

Copy link
Collaborator

Yes, the Metal implementation of silu requires its input to be contiguous; the swiglu implementation, however, does not. :)
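
For readers following along, here is a minimal sketch of the two equivalent formulations discussed above. It is not part of the PR diff: the helper names are made up, and it assumes ggml_swiglu_split(ctx, a, b) computes silu(a) * b element-wise, as the suggested change implies.

#include "ggml.h"

// original form: y * silu(z); Metal's silu kernel needs a contiguous input,
// hence the ggml_cont on z
static struct ggml_tensor * gate_with_silu(struct ggml_context * ctx0,
                                           struct ggml_tensor * y,
                                           struct ggml_tensor * z) {
    return ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z)));
}

// fused form: one SwiGLU op instead of mul(silu(...)); the ggml_cont is kept
// only for backends that still require contiguous input here (e.g. Vulkan)
static struct ggml_tensor * gate_with_swiglu_split(struct ggml_context * ctx0,
                                                   struct ggml_tensor * y,
                                                   struct ggml_tensor * z) {
    return ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
}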

Copy link
Collaborator

@CISC CISC left a comment

Let's gooo :)

@gabe-l-hart
Copy link
Contributor

@compilade I merged the latest master with Falcon H1 into #13550 and have tested that the following models convert, quantize, and run correctly:

In case you haven't made the same changes locally, the merge resolution commit is 1334c7 followed by f8b81c to fix the hybrid input use.

@gabe-l-hart
Copy link
Contributor

Jinx! I'll merge yours into mine again. The one outstanding question I have here is whether we should make the order of Falcon H1 consistent between the various enum declarations and usages. In my merge resolution, I moved it so that Falcon H1 always comes directly after Falcon.

@compilade
Copy link
Collaborator Author

compilade commented Jul 9, 2025

@gabe-l-hart Thanks. I've also merged the changes here (apparently we were doing that at the same time :)), and I've tested inference with fresh conversions of the following models:

The one outstanding question I have in this is whether we should make the order of Falcon H1 consistent between the various enum declarations and usages.

Hmm, you're right that the order should usually be consistent. There might be some order dependencies between the structs on the C++ side, though (with the shared mamba2 layer builder).

I lean toward fixing this in its own PR (and then verifying that the changes only move lines with git log -p --color-moved, and maybe testing a few Falcon-H1 models again).
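
For illustration (not a command from the thread), such a moved-lines check could look something like this, where <commit> stands for whichever reordering commit is being verified:

git log -p --color-moved=zebra -1 <commit>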

@gabe-l-hart
Copy link
Contributor

Makes sense. I'll avoid shuffling things on GR4, and we can defer that to another PR.

(sidebar, I love learning new git tricks!)

Some of the tensor names are common with Llama4
@compilade
Copy link
Collaborator Author

Ok, this time I think it's ready. Will merge after the CI passes.

@compilade compilade merged commit 4a5686d into master Jul 9, 2025
52 checks passed
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Jul 9, 2025
* origin/master:
llama : support Jamba hybrid Transformer-Mamba models (ggml-org#7531)
ggml : add ggml_scale_bias (ggml-org#14417)
qnixsynapse pushed a commit to menloresearch/llama.cpp that referenced this pull request Jul 10, 2025
* wip: llama : separate recurrent states from the KV cache

This will be necessary to support Jamba
(and other recurrent models mixed with Attention).

Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states.

* llama : use std::find for seq_nodes in llama_rs_cache

* llama : state checkpoints for recurrent models

* llama : correctly handle more edge cases for the rs cache

* llama : rename many llama_kv_cache_* functions

* llama : remove useless return value for some llama_cache_* functions

* llama : rethink recurrent state cell counts

* llama : begin work on support for variable GQA

This will also be useful for Jamba if we consider the Mamba layers
to have 0 KV heads.

* llama : gracefully fail when not finding hybrid slot

* llama : support Jamba

* llama : fix BERT inference without KV cache

* convert-hf : check for unprocessed Jamba experts

* convert-hf : support Mini-Jamba conversion

* llama : fix Jamba quantization sanity checks

* llama : sequence-length-aware batch splitting

* llama : use equal-sequence-length sub-batches for recurrent models

* ggml : simplify SSM-related operators

* llama : make recurrent state slot allocation contiguous

* llama : adapt internal uses of batches to llama_ubatch

* llama : fix batch split output count for embeddings

* llama : minimize swaps when reordering logits

This reduces overhead when running hellaswag
on thousands of sequences with very small 100k params Mamba models.

* llama : fix edge case finding batch seq_id of split recurrent cell

This otherwise was a problem when running the HellaSwag benchmark
with small batch sizes, making it crash.

* llama : avoid copies for simple batch splits

* ggml : make ggml_ssm_scan not modify its source tensors

* llama : fix shared recurrent tail cell count for small ubatch sizes

Otherwise it was impossible to run the 'parallel' example with '-ub 1'
with a Mamba or Jamba model.

* llama : fix .base() compilation error on Windows

* llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL

* ggml : allow GGML_OP_CONCAT to work on non-contiguous tensors

The implementation already supported it,
and this makes Mamba's conv step slightly faster.

* mamba : fix non-contiguous usage of ggml_silu

* llama : session saving and reloading for hybrid models

* convert_hf : fix Jamba conversion

* llama : fix mixed signedness comparison

* llama : use unused n_embd_k_gqa in k_shift

This also slightly reduces the diff from the master branch

* llama : begin renaming llama_past back to llama_kv_cache

* llama : remove implicit recurrent state rollbacks

* llama : partially apply clang-format style

* convert : fix jamba conv1d shape squeezing

* graph : add back hybrid memory graph input

But this time it contains the sub-cache graph inputs.
This *should* make it easier to handle updating the inputs
when caching the graph (eventually).

* model : add Jamba to Mamba-specific hparams printing

* jamba : remove redundant nullptr initializations

* model : remove unnecessary prefix for tensor loading constants

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* model : use ggml_swiglu_split for Mamba

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* model : make falcon-h1 use shared mamba2 layer builder

* memory : avoid referring to KV in recurrent cache logs

* gguf-py : avoid adding duplicate tensor mappings for Jamba

Some of the tensor names are common with Llama4

---------

Co-authored-by: Sigbjørn Skjæret <[email protected]>
@ggerganov ggerganov added the "hot" label (Something that is hot) Jul 11, 2025
Labels
android, embeddings, enhancement, examples, ggml, hot, merge ready, model, need feedback, python, refactoring, Review Complexity : High, server
Development

Successfully merging this pull request may close these issues.

Support for Jamba JambaForCausalLM